'''
Description: draw boxes of char,word,line,region 
Author: KouXichao
Date: 2020-10-30 13:18:03
LastEditors: KouXichao
LastEditTime: 2021-03-25 11:50:54
FilePath: /LayoutGen/tools/procfile.py
Github: https://github.com/kouxichao
'''
#coding=utf-8
import sys
import os
import numpy as np
sys.path.append(os.path.abspath('./'))
from PIL import Image, ImageDraw, ImageFont
from eval_rec_acc import SynthLayout
import cv2
import config
import shutil
import json
import xml.etree.ElementTree as ET

def mergeRects(rects):
    """Return the axis-aligned quadrilateral that encloses every input quad.

    Args:
        rects: sequence of quads, each given as eight numbers
            ``x0 y0 x1 y1 x2 y2 x3 y3``. Values are cast to ``int`` first.

    Returns:
        A list ``[x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max]``.
    """
    quads = np.array(rects, dtype=int)  # one row per quad: x0 y0 x1 y1 x2 y2 x3 y3
    # Per-column extremes; even columns hold x values, odd columns hold y values.
    col_hi = quads.max(axis=0)
    col_lo = quads.min(axis=0)

    x_max = max(col_hi[0], col_hi[2], col_hi[4], col_hi[6])
    x_min = min(col_lo[0], col_lo[2], col_lo[4], col_lo[6])
    y_max = max(col_hi[1], col_hi[3], col_hi[5], col_hi[7])
    y_min = min(col_lo[1], col_lo[3], col_lo[5], col_lo[7])

    return [x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max]

import torch.utils.data
class COCODataset(torch.utils.data.Dataset):
    """Dataset over a COCO-format annotation file.

    Each item is ``(image, boxes, labels, file_name)`` where ``image`` is an
    RGB array, ``boxes`` is a float32 ``(N, 4)`` array in ``x1 y1 x2 y2``
    order and ``labels`` is an int64 ``(N,)`` array of contiguous category
    ids starting at 1.
    """

    # Region category names exposed to callers of this class.
    class_names = ('text', 'title', 'list', 'table', 'figure')

    def __init__(self, data_dir, ann_file, transform=None, target_transform=None, remove_empty=False):
        """Load ``ann_file`` and index its images.

        Args:
            data_dir: folder that image ``file_name`` entries are relative to.
            ann_file: path of the COCO annotation JSON.
            transform: optional callable ``(image, boxes, labels)`` returning
                the same triple.
            target_transform: optional callable ``(boxes, labels)`` returning
                the same pair.
            remove_empty: when true, only images that have annotations are
                kept (training); otherwise every image is used (testing).
        """
        # Imported here rather than at module level.
        from pycocotools.coco import COCO

        self.coco = COCO(ann_file)
        self.data_dir = data_dir
        self.transform = transform
        self.target_transform = target_transform
        self.remove_empty = remove_empty

        id_source = self.coco.imgToAnns if self.remove_empty else self.coco.imgs
        self.ids = list(id_source.keys())

        # Map the sorted COCO category ids onto 1..K and back.
        self.coco_id_to_contiguous_id = {
            cat_id: rank
            for rank, cat_id in enumerate(sorted(self.coco.getCatIds()), start=1)
        }
        self.contiguous_id_to_coco_id = {
            rank: cat_id for cat_id, rank in self.coco_id_to_contiguous_id.items()
        }

    def __getitem__(self, index):
        """Return ``(image, boxes, labels, file_name)`` for position ``index``."""
        image_id = self.ids[index]
        boxes, labels = self._get_annotation(image_id)
        image = self._read_image(image_id)
        file_name = self.coco.loadImgs(image_id)[0]['file_name']
        if self.transform:
            image, boxes, labels = self.transform(image, boxes, labels)
        if self.target_transform:
            boxes, labels = self.target_transform(boxes, labels)
        return image, boxes, labels, file_name

    def get_annotation(self, index):
        """Return ``(image_id, (boxes, labels))`` for position ``index``."""
        image_id = self.ids[index]
        return image_id, self._get_annotation(image_id)

    def __len__(self):
        """Number of indexed images."""
        return len(self.ids)

    def _get_annotation(self, image_id):
        """Collect the non-crowd, non-degenerate boxes and labels of one image."""
        records = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))
        # Crowd annotations are skipped.
        records = [rec for rec in records if rec["iscrowd"] == 0]

        boxes = np.array(
            [self._xywh2xyxy(rec["bbox"]) for rec in records], np.float32
        ).reshape((-1, 4))
        labels = np.array(
            [self.coco_id_to_contiguous_id[rec["category_id"]] for rec in records], np.int64
        ).reshape((-1,))

        # Drop boxes whose width or height is not strictly positive.
        valid = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        return boxes[valid], labels[valid]

    def _xywh2xyxy(self, box):
        """Convert ``[x, y, w, h]`` into ``[x1, y1, x2, y2]``."""
        left, top, width, height = box
        return [left, top, left + width, top + height]

    def get_img_info(self, index):
        """Return the COCO image record stored for position ``index``."""
        return self.coco.imgs[self.ids[index]]

    def _read_image(self, image_id):
        """Load the image file of ``image_id`` as an RGB numpy array."""
        file_name = self.coco.loadImgs(image_id)[0]['file_name']
        rgb = Image.open(os.path.join(self.data_dir, file_name)).convert("RGB")
        return np.array(rgb)

def rm_mkdir(dataPath, name):
    """Recreate ``dataPath/name`` as an empty directory.

    An existing directory tree at that location is deleted first. The parent
    ``dataPath`` must already exist.
    """
    target = os.path.join(dataPath, name)
    if os.path.exists(target):
        shutil.rmtree(target)
    # After the removal above the path is absent, so the folder is always created anew.
    os.mkdir(target)
            
# (category keyword, outline colour, caption) used for COCO-style regions.
_COCO_REGION_STYLES = (
    ('text', 'green', '文本'),
    ('title', 'gray', '标题'),
    ('table', 'red', '表格'),
    ('figure', 'blue', '图像'),
    ('list', 'purple', '列表'),
)

# (label keyword, outline colour) used for regions read from per-image JSON files.
_SHAPE_REGION_STYLES = (
    ('text', 'green'),
    ('title', 'gray'),
    ('table', 'red'),
    ('image', 'blue'),
    ('formula', 'purple'),
)

# (category keyword, outline colour, caption) used for regions read from XML files.
_XML_REGION_STYLES = (
    ('table', 'red', '表格'),
    ('figure', 'blue', '图像'),
    ('equation', 'green', '公式'),
)


def _split_image_name(name):
    """Split a listed image name into ``(stem, file_name)``.

    A name containing a dot is used verbatim as the file name and its stem is
    the text before the first dot; a name without a dot gets ``.jpg`` appended.
    """
    if '.' in name:
        return name.split('.')[0], name
    return name, name + '.jpg'


def _save_region(im, corners, obj_type, image_stem, data_path, counters):
    """Crop one region from ``im`` and save it under ``clip_region/<obj_type>/``.

    Args:
        im: PIL image to crop from.
        corners: ``(x1, y1, x2, y2)`` of the region.
        obj_type: category name; used as sub-folder and in the file name.
        image_stem: image name without extension, used as file-name prefix.
        data_path: dataset root that contains ``clip_region``.
        counters: dict mapping category name to the number of crops saved so
            far; updated in place.
    """
    x1, y1, x2, y2 = corners
    # Pad the crop by one pixel on every side, clamped to the image.
    box = (max(0, x1 - 1), max(0, y1 - 1),
           min(im.size[0] - 1, x2 + 1), min(im.size[1] - 1, y2 + 1))
    out_dir = os.path.join(data_path, 'clip_region', obj_type)
    # Categories that were not pre-created still get a folder and a counter.
    os.makedirs(out_dir, exist_ok=True)
    serial = counters.get(obj_type, 0)
    counters[obj_type] = serial + 1
    # NOTE(review): the crop is taken from the image that is being annotated,
    # so overlays drawn earlier on the same image can appear in the crop.
    im.crop(box).save(os.path.join(out_dir, image_stem + '_' + obj_type + str(serial) + '.jpg'))


def _draw_text_boxes(dr, draw_type, char_bboxes, word_bboxes, line_bboxes):
    """Draw character / word / text-line boxes selected by ``draw_type`` bits.

    Bit ``0x01`` draws character boxes, ``0x02`` word boxes and ``0x04``
    text-line boxes. Every box is flattened to ``x0 y0 x1 y1 x2 y2 x3 y3`` and
    drawn as the rectangle spanned by its first and third corner.
    """
    # The containers are iterated directly instead of going through
    # ``np.array(...).shape[0]``, which fails on ragged nested lists.
    if draw_type & 0x01:
        # Three nesting levels, as in the ground truth returned by the loader.
        for group in char_bboxes:
            for sub_group in group:
                for cb in sub_group:
                    cb = np.array(cb).reshape((-1))
                    dr.rectangle([(float(cb[0]), float(cb[1])), (float(cb[4]), float(cb[5]))],
                                 outline='green', width=3)

    if draw_type & 0x02:
        for group in word_bboxes:
            for wb in group:
                # The last entry of a word record is not a coordinate
                # (presumably the word text -- confirm against SynthLayout).
                wb = np.array(wb[:-1]).reshape((-1))
                dr.rectangle([(float(wb[0]), float(wb[1])), (float(wb[4]), float(wb[5]))],
                             outline='green', width=5)

    if draw_type & 0x04:
        for line_box in line_bboxes:
            arr = np.array(line_box).reshape((-1))
            dr.rectangle([(float(arr[0]), float(arr[1])), (float(arr[4]), float(arr[5]))],
                         outline='blue', width=2)


if __name__ == '__main__':
    # Draws ground-truth boxes on every image of a dataset and writes the
    # annotated copies to <dataPath>/procfile_image/.
    #
    # --draw_type is a bit mask:
    #   0x01 character boxes   (LayoutGen data only)
    #   0x02 word boxes        (LayoutGen data only)
    #   0x04 text-line boxes   (LayoutGen data only)
    #   0x08 region boxes from the COCO loader (data_format json) or from
    #        <dataPath>/json/<image>.json (any other data_format)
    #   0x10 region boxes from <dataPath>/xml/<image>.xml
    import argparse
    parser = argparse.ArgumentParser(description="!!!!!!!!!!!!Draw rects for text and other contents!!!!!!!!!!!!!")
    parser.add_argument("-dp", "--dataPath", default='./synthetic_data/', type=str, help="Layout data folder")
    parser.add_argument("-dt", "--draw_type", default=15, type=int, help="Type of rects to draw")
    parser.add_argument("-w", "--warp", default=False, action='store_true', help="If there are warp images")
    parser.add_argument("-rm", "--remove", default=False, action='store_true', help="Remove procfile_image")
    parser.add_argument("-si", "--save_region_img", default=False, action='store_true', help="Save image of region")
    parser.add_argument('-df', "--data_format", default="LayoutGen", type=str, help="Data type")
    parser.add_argument('-tn', "--text_name", default="train.txt", type=str, help="Text name of data, this options is valid when option data_format is xml")
    opt = parser.parse_args()

    imagenames = []
    dataloader = None
    if opt.data_format == 'xml':
        categories = ['figure', 'formula']
        with open(os.path.join(opt.dataPath, opt.text_name), 'r') as f:
            imagenames = f.read().splitlines()
        nums = len(imagenames)
    elif opt.data_format == 'json':
        dataloader = COCODataset(os.path.join(opt.dataPath, 'val'), os.path.join(opt.dataPath, 'val.json'))
        # Only the first 100 images are rendered; smaller sets are rendered fully.
        nums = min(100, len(dataloader))
        categories = dataloader.class_names
    else:
        dataloader = SynthLayout(opt.dataPath)
        nums = len(dataloader)
        # No category folders are pre-created for this format; they are made on demand.
        categories = ()

    pth = os.path.join(opt.dataPath, 'procfile_image/')
    if opt.remove and os.path.exists(pth):
        shutil.rmtree(pth)
    os.makedirs(pth, exist_ok=True)

    # Number of region crops saved so far, per category.
    region_counts = {}
    if opt.save_region_img:
        region_pth = os.path.join(opt.dataPath, 'clip_region/')
        if opt.remove and os.path.exists(region_pth):
            shutil.rmtree(region_pth)
        os.makedirs(region_pth, exist_ok=True)

        for category in categories:
            region_counts[category] = 0
            rm_mkdir(os.path.join(opt.dataPath, "clip_region"), category)

    # The caption font does not depend on the image, so it is loaded once.
    textsize = 30
    ft = ImageFont.truetype("./font/cn/simsun.ttf", textsize)

    for img_idx in range(nums):
        char_bboxes = word_bboxes = line_bboxes = None
        region_boxes = labels = None
        if opt.data_format == 'json':
            _, region_boxes, labels, listed_name = dataloader[img_idx]
            imagePath = os.path.join(opt.dataPath, "val/" + listed_name)
            imageName, imageName_ex = _split_image_name(listed_name)
        elif opt.data_format == 'xml':
            imageName, imageName_ex = _split_image_name(imagenames[img_idx])
            imagePath = os.path.join(opt.dataPath, "image/" + imageName_ex)
        else:
            _, char_bboxes, word_bboxes, line_bboxes, imageName = dataloader.load_image_gt(img_idx)
            imageName_ex = imageName + '_warp.jpg' if opt.warp else imageName + '.jpg'
            imagePath = os.path.join(opt.dataPath, "image/" + imageName_ex)

        im = Image.open(imagePath)
        dr = ImageDraw.Draw(im)

        if opt.data_format != 'json' and opt.data_format != 'xml':
            _draw_text_boxes(dr, opt.draw_type, char_bboxes, word_bboxes, line_bboxes)

        # Region boxes and region types (json).
        if opt.draw_type & 0x08:
            print("json ", imageName)
            if opt.data_format == 'json':
                for rect, label in zip(region_boxes, labels):
                    # Plain floats keep PIL independent of the numpy scalar type.
                    x1, y1, x2, y2 = (float(v) for v in rect[:4])
                    # Labels are contiguous ids starting at 1.
                    obj_type = dataloader.class_names[label - 1]

                    if opt.save_region_img:
                        _save_region(im, (x1, y1, x2, y2), obj_type, imageName, opt.dataPath, region_counts)

                    rp = [(x1, y1), (x2, y2)]
                    for keyword, color, caption in _COCO_REGION_STYLES:
                        if keyword in obj_type:
                            dr.rectangle(rp, outline=color, width=3)
                            dr.text((x1, y1), caption, font=ft, fill=color)
            else:
                json_path = os.path.join(opt.dataPath, './json/' + imageName + '.json')
                with open(json_path, encoding="utf-8") as f:
                    region_dict = json.load(f)
                line_width = 7
                for region in region_dict["shapes"]:
                    label = region["label"]
                    corners = [(pt[0], pt[1]) for pt in region["points"]]
                    # First and second-to-last point span the drawn rectangle.
                    rp = [corners[0], corners[-2]]
                    for keyword, color in _SHAPE_REGION_STYLES:
                        if keyword in label:
                            dr.rectangle(rp, outline=color, width=line_width)

        # Region boxes and region types (xml).
        if opt.draw_type & 0x10:
            print("xml ", imageName)
            filename = os.path.join(opt.dataPath, './xml/' + imageName + '.xml')
            tree = ET.parse(filename)

            for obj in tree.findall('object'):
                bbox = obj.find('bndbox')
                # Coordinates are used exactly as stored in the file.
                x1 = float(bbox.find('xmin').text)
                y1 = float(bbox.find('ymin').text)
                x2 = float(bbox.find('xmax').text)
                y2 = float(bbox.find('ymax').text)
                obj_type = obj.find('name').text.lower().strip()
                if opt.save_region_img:
                    _save_region(im, (x1, y1, x2, y2), obj_type, imageName, opt.dataPath, region_counts)

                rp = [(x1, y1), (x2, y2)]
                for keyword, color, caption in _XML_REGION_STYLES:
                    if keyword in obj_type:
                        dr.rectangle(rp, outline=color, width=3)
                        dr.text((x1, y1), caption, font=ft, fill=color)
                        break
                else:
                    # Unknown category: report it and wait for the operator.
                    print("沒有此类型 %s"%(obj_type))
                    input()

        im.save(os.path.join(pth, imageName_ex))   # save the annotated image
